{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from transformers import PreTrainedTokenizerFast, PhiForCausalLM, TrainingArguments, Trainer, TrainerCallback\n",
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "import time\n",
    "import torch\n",
    "from trl import DataCollatorForCompletionOnlyLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 定义训练数据，tokenizer，预训练模型的路径及最大长度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sft_file = './data/sft_train_data.parquet'\n",
    "tokenizer_dir = './model_save/tokenizer/'\n",
    "sft_from_checkpoint_file = './model_save/pre/'\n",
    "model_save_dir = './model_save/sft/'\n",
    "max_seq_len = 320"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 加载训练数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(path='parquet', data_files=sft_file, split='train', cache_dir='.cache')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'output'],\n",
       "    num_rows: 726475\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'instruction': ['为我寻找5个值得信任的汽车保险公司', 'A公司去年亏损了500万美元，今年净利润增长了50%，今年的净利润是多少？'], 'output': ['1. State Farm\\n2. Geico\\n3. Allstate\\n4. Progressive\\n5. Farmers Insurance', '今年净利润为750万美元']}\n"
     ]
    }
   ],
   "source": [
    "samples = dataset[0:2]\n",
    "print(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vicab size: 35840\n"
     ]
    }
   ],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)\n",
    "print(f\"vicab size: {len(tokenizer)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 定义sft data_collator的指令字符\n",
    "注释掉的这段代码是手动将`instruction_template_ids`和`response_template_ids`添加到input_ids中的，因为如果是byte level tokenizer可能将`:`和后面的字符合并，导致找不到`instruction_template_ids`和`response_template_ids`。 \n",
    "\n",
    "也可以像下文一样通过在`'#'`和`':'`前后手动加`'\\n'`解决"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "instruction_template = \"##提问:\"\n",
    "response_template = \"##回答:\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注释掉的这段代码是手动将`instruction_template_ids`和`response_template_ids`添加到input_ids中\n",
    "\n",
    "# template_ids = tokenizer([instruction_template, response_template])['input_ids']\n",
    "# instruction_template_ids, response_template_ids = template_ids[0], template_ids[1]\n",
    "# print(instruction_template_ids, response_template_ids)\n",
    "# def formatting_prompts_func(example: list[dict]) -> list[str]:\n",
    "#     batch_prompt,  batch_response = [], []\n",
    "#     n = len(example['instruction'])\n",
    "#     for i in range(n):\n",
    "#         batch_prompt.append(example['instruction'][i])\n",
    "#         batch_response.append(example['output'][i])\n",
    "        \n",
    "#     prompt_ids = tokenizer(batch_prompt, return_attention_mask=False)['input_ids']\n",
    "#     resopnse_ids = tokenizer(batch_response, return_attention_mask=False)['input_ids']\n",
    "\n",
    "#     input_ids = []\n",
    "#     for i in range(n):\n",
    "#         cur_input_ids = [tokenizer.bos_token_id] + instruction_template_ids + prompt_ids[i] \\\n",
    "#                         + response_template_ids + resopnse_ids[i] + [tokenizer.eos_token_id]\n",
    "#         input_ids.append(cur_input_ids)\n",
    "    \n",
    "#     return {'input_ids': input_ids}\n",
    "\n",
    "# from typing import List, Union\n",
    "\n",
    "\n",
    "# class Phi2DataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):\n",
    "#     def __init__(self, response_template: str | List[int], instruction_template: str | List[int] = None, *args, mlm: bool = False, ignore_index: int = -100, **kwargs):\n",
    "#         super().__init__(response_template, instruction_template, *args, mlm=mlm, ignore_index=ignore_index, **kwargs)\n",
    "    \n",
    "#     def __call__(self, features, return_tensors=None):\n",
    "#         '''\n",
    "#         执行formatting_prompts_func map后，dataset的__getitem__方法返回的是batch_size个input_ids\n",
    "#         '''\n",
    "#         batch_size = len(features)\n",
    "#         paded_input_ids = tokenizer.pad(\n",
    "#             {'input_ids': features['input_ids']},\n",
    "#             padding=True,\n",
    "#             return_attention_mask=False,\n",
    "#         )['input_ids']\n",
    "\n",
    "#         data = []\n",
    "#         for i in range(batch_size):\n",
    "#             data.append(\n",
    "#                 {'input_ids': }\n",
    "#             )\n",
    "\n",
    "#         # 最后让父类执行LM mask即可\n",
    "#         return super().__call__(data, return_tensors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [[9, 9, 11273, 32, 177, 4311, 1903, 27, 256, 2502, 2859, 7219, 17430, 177, 9, 9, 823, 32, 177, 23, 20, 8526, 4380, 3994, 20602, 177, 24, 20, 4023, 75, 1834, 85, 177, 25, 20, 15415, 15426, 4380, 177, 26, 20, 8498, 10559, 4955, 6359, 177, 27, 20, 3994, 20602, 2913, 9833, 89, 2766, 9435, 1], [9, 9, 11273, 32, 177, 39, 747, 8464, 22851, 283, 2848, 7597, 221, 6226, 9444, 14617, 1489, 5938, 12885, 9444, 3032, 404, 177, 9, 9, 823, 32, 177, 6226, 9444, 276, 31674, 7597, 1]]}\n"
     ]
    }
   ],
   "source": [
    "def batched_formatting_prompts_func(example: list[dict]) -> list[str]:\n",
    "    batch_txt = []\n",
    "    for i in range(len(example['instruction'])):\n",
    "        text = f\"{instruction_template}\\n{example['instruction'][i]}\\n{response_template}\\n{example['output'][i]}[EOS]\"\n",
    "        batch_txt.append(text)\n",
    "\n",
    "    # token to id \n",
    "    input_ids = tokenizer(batch_txt, return_attention_mask=False)['input_ids']\n",
    "\n",
    "    return {'input_ids': input_ids}\n",
    "\n",
    "print(batched_formatting_prompts_func(samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(batched_formatting_prompts_func, batched=True, remove_columns=dataset.column_names).shuffle(23333)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 定义data_collator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mlm=False表示训练的是CLM模型\n",
    "data_collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. 加载预训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Phi2 size: 193.65M parameters\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = PhiForCausalLM.from_pretrained(sft_from_checkpoint_file)\n",
    "\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"Phi2 size: {model_size / 1000**2:.2f}M parameters\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义训练过程中的回调函数\n",
    "N次log之后情况cuda缓存，能有效缓解低显存机器显存缓慢增长的问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmptyCudaCacheCallback(TrainerCallback):\n",
    "    log_cnt = 0\n",
    "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
    "        self.log_cnt += 1\n",
    "        if self.log_cnt % 5 == 0:\n",
    "            torch.cuda.empty_cache()\n",
    "            \n",
    "empty_cuda_cahce = EmptyCudaCacheCallback()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. 定义训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=model_save_dir,\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.1,\n",
    "    warmup_steps=1000,\n",
    "    learning_rate=5e-5,\n",
    "    save_steps=2000,\n",
    "    save_total_limit=3,\n",
    "    report_to='tensorboard',\n",
    "    optim=\"adafactor\",\n",
    "    bf16=True,\n",
    "    logging_steps=10,\n",
    "    log_level='info',\n",
    "    logging_first_step=True,\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=dataset,\n",
    "    callbacks=[empty_cuda_cahce]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6. 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train(\n",
    "    # resume_from_checkpoint=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7. 保存日志和模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./model_save/sft/\n",
      "Configuration saved in ./model_save/sft/config.json\n",
      "Configuration saved in ./model_save/sft/generation_config.json\n",
      "Model weights saved in ./model_save/sft/pytorch_model.bin\n",
      "tokenizer config file saved in ./model_save/sft/tokenizer_config.json\n",
      "Special tokens file saved in ./model_save/sft/special_tokens_map.json\n"
     ]
    }
   ],
   "source": [
    "loss_log = pd.DataFrame(trainer.state.log_history)\n",
    "loss_log.to_csv(f\"./logs/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n",
    "\n",
    "\n",
    "trainer.save_model(model_save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
